// Copyright 2020 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

mod error;

pub(crate) mod pool;
mod supported_protocols;

use std::{
    collections::{HashMap, HashSet},
    fmt,
    fmt::{Display, Formatter},
    future::Future,
    io, mem,
    pin::Pin,
    sync::atomic::{AtomicUsize, Ordering},
    task::{Context, Poll, Waker},
    time::Duration,
};

pub use error::ConnectionError;
pub(crate) use error::{PendingInboundConnectionError, PendingOutboundConnectionError};
use futures::{future::BoxFuture, stream, stream::FuturesUnordered, FutureExt, StreamExt};
use futures_timer::Delay;
use libp2p_core::{
    connection::ConnectedPoint,
    multiaddr::Multiaddr,
    muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox},
    transport::PortUse,
    upgrade,
    upgrade::{NegotiationError, ProtocolError},
    Endpoint,
};
use libp2p_identity::PeerId;
pub use supported_protocols::SupportedProtocols;
use web_time::Instant;

use crate::{
    handler::{
        AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError,
        FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport,
        ProtocolsChange, UpgradeInfoSend,
    },
    stream::ActiveStreamCounter,
    upgrade::{InboundUpgradeSend, OutboundUpgradeSend},
    ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol,
};

/// Counter backing [`ConnectionId::next`]; the first identifier handed out is `1`.
static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Connection identifier.
///
/// A thin, copyable wrapper around a `usize` that is ordered and hashable, so it can be used
/// as a map key.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Creates an _unchecked_ [`ConnectionId`] wrapping `id` verbatim.
    ///
    /// [`Swarm`](crate::Swarm) enforces that [`ConnectionId`]s are unique and not reused.
    /// This constructor performs no such check, hence the _unchecked_.
    ///
    /// Its main purpose is to allow manual tests of
    /// [`NetworkBehaviour`](crate::NetworkBehaviour)s.
    pub fn new_unchecked(id: usize) -> Self {
        ConnectionId(id)
    }

    /// Returns the next available [`ConnectionId`] by atomically bumping the global counter.
    pub(crate) fn next() -> Self {
        let raw = NEXT_CONNECTION_ID.fetch_add(1, Ordering::SeqCst);
        ConnectionId(raw)
    }
}

impl Display for ConnectionId {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Information about a successfully established connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Connected {
    /// The connected endpoint, including network address information.
    pub(crate) endpoint: ConnectedPoint,
    /// The [`PeerId`] of the remote, as obtained from the transport.
    pub(crate) peer_id: PeerId,
}

/// Event generated by a [`Connection`].
///
/// Returned by `Connection::poll` wrapped in `Poll::Ready(Ok(_))`.
#[derive(Debug, Clone)]
pub(crate) enum Event<T> {
    /// Event generated by the [`ConnectionHandler`].
    ///
    /// Carries the payload of a `ConnectionHandlerEvent::NotifyBehaviour`.
    Handler(T),
    /// Address of the remote has changed.
    ///
    /// Emitted after the handler has already been informed through
    /// `ConnectionEvent::AddressChange`.
    AddressChange(Multiaddr),
}

/// A multiplexed connection to a peer with an associated [`ConnectionHandler`].
///
/// The connection owns both the muxer and the handler. Its `poll` method opens outbound
/// substreams on the handler's request, accepts inbound ones, and negotiates protocols on
/// both. It also keeps both sides informed of protocol changes and decides when an idle
/// connection should shut down.
pub(crate) struct Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// Node that handles the muxing.
    muxing: StreamMuxerBox,
    /// The underlying handler.
    handler: THandler,
    /// Futures that upgrade incoming substreams.
    negotiating_in: FuturesUnordered<
        StreamUpgrade<
            THandler::InboundOpenInfo,
            <THandler::InboundProtocol as InboundUpgradeSend>::Output,
            <THandler::InboundProtocol as InboundUpgradeSend>::Error,
        >,
    >,
    /// Futures that upgrade outgoing substreams.
    negotiating_out: FuturesUnordered<
        StreamUpgrade<
            THandler::OutboundOpenInfo,
            <THandler::OutboundProtocol as OutboundUpgradeSend>::Output,
            <THandler::OutboundProtocol as OutboundUpgradeSend>::Error,
        >,
    >,
    /// The currently planned connection & handler shutdown.
    shutdown: Shutdown,
    /// The substream upgrade protocol override, if any.
    ///
    /// Only applied when negotiating outbound substreams.
    substream_upgrade_protocol_override: Option<upgrade::Version>,
    /// The maximum number of inbound streams concurrently negotiating on a
    /// connection.
    ///
    /// Once the limit is reached, the muxer is not polled for new inbound streams until one of
    /// the ongoing negotiations finishes. What happens to streams the remote opens in the
    /// meantime is up to the muxer.
    ///
    /// Note: This only enforces a limit on the number of concurrently
    /// negotiating inbound streams. The total number of inbound streams on a
    /// connection is the sum of negotiating and negotiated streams. A limit on
    /// the total number of streams can be enforced at the [`StreamMuxerBox`] level.
    max_negotiating_inbound_streams: usize,
    /// Contains all upgrades that are waiting for a new outbound substream.
    ///
    /// The upgrade timeout is already ticking here so this may fail in case the remote is not
    /// quick enough in providing us with a new stream.
    requested_substreams: FuturesUnordered<
        SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
    >,

    /// Snapshot of the handler's inbound (listen) protocols, keyed by their string form.
    ///
    /// Seeded from `listen_protocol()` in `new`, with every value set to `true`. In `poll` it
    /// is compared against the handler's current listen protocols to detect local protocol
    /// changes.
    /// NOTE(review): the exact meaning of the `bool` is defined by
    /// `ProtocolsChange::from_full_sets`, which is not visible here — confirm.
    local_supported_protocols:
        HashMap<AsStrHashEq<<THandler::InboundProtocol as UpgradeInfoSend>::Info>, bool>,
    /// Protocols the handler has reported (via `ReportRemoteProtocols`) as supported by the
    /// remote.
    remote_supported_protocols: HashSet<StreamProtocol>,
    /// Scratch buffer reused across protocol-change computations.
    protocol_buffer: Vec<StreamProtocol>,

    /// How long the connection may stay idle before `poll` reports a keep-alive timeout.
    ///
    /// With `Duration::ZERO`, the shutdown happens as soon as the connection is idle and the
    /// handler does not keep it alive.
    idle_timeout: Duration,
    /// Cloned into every substream upgrade. The connection only counts as idle while it
    /// reports no active streams.
    stream_counter: ActiveStreamCounter,
}

// Only the handler is included in the debug output. The remaining fields (muxer, in-flight
// upgrades, protocol sets, ...) are omitted. Only `THandler: fmt::Debug` is required because
// the body formats nothing else.
impl<THandler> fmt::Debug for Connection<THandler>
where
    THandler: ConnectionHandler + fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Connection")
            .field("handler", &self.handler)
            .finish()
    }
}

// `Connection::poll` reaches its fields through `Pin::get_mut`, which requires `Unpin`. No
// field is ever pin-projected, so this holds regardless of whether `THandler` is `Unpin`.
impl<THandler: ConnectionHandler> Unpin for Connection<THandler> {}

impl<THandler> Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// Builds a new `Connection` from the given substream multiplexer
    /// and connection handler.
    ///
    /// The handler's current listen protocols are collected up front and cached in the
    /// connection. If that set is non-empty, the handler is told about it immediately through
    /// [`ConnectionEvent::LocalProtocolsChange`], before the connection is returned.
    ///
    /// - `muxer`: multiplexer used to open outbound and accept inbound substreams.
    /// - `handler`: the handler driving this connection.
    /// - `substream_upgrade_protocol_override`: multistream-select version to use for outbound
    ///   substream negotiation instead of the default, if any.
    /// - `max_negotiating_inbound_streams`: upper bound on inbound substreams negotiating at
    ///   the same time.
    /// - `idle_timeout`: how long the connection may stay idle (nothing negotiating, no active
    ///   streams, handler not keeping it alive) before `poll` reports
    ///   `ConnectionError::KeepAliveTimeout`.
    pub(crate) fn new(
        muxer: StreamMuxerBox,
        mut handler: THandler,
        substream_upgrade_protocol_override: Option<upgrade::Version>,
        max_negotiating_inbound_streams: usize,
        idle_timeout: Duration,
    ) -> Self {
        let initial_protocols = gather_supported_protocols(&handler);
        // Scratch buffer for protocol-change computations. It is stored in the connection
        // (`protocol_buffer`) afterwards so later computations can reuse it.
        let mut buffer = Vec::new();

        if !initial_protocols.is_empty() {
            handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
                ProtocolsChange::from_initial_protocols(
                    initial_protocols.keys().map(|e| &e.0),
                    &mut buffer,
                ),
            ));
        }

        Connection {
            muxing: muxer,
            handler,
            negotiating_in: Default::default(),
            negotiating_out: Default::default(),
            shutdown: Shutdown::None,
            substream_upgrade_protocol_override,
            max_negotiating_inbound_streams,
            requested_substreams: Default::default(),
            local_supported_protocols: initial_protocols,
            remote_supported_protocols: Default::default(),
            protocol_buffer: buffer,
            idle_timeout,
            stream_counter: ActiveStreamCounter::default(),
        }
    }

    /// Notifies the connection handler of an event.
    pub(crate) fn on_behaviour_event(&mut self, event: THandler::FromBehaviour) {
        self.handler.on_behaviour_event(event);
    }

    /// Begins an orderly shutdown of the connection, returning a stream of final events and a
    /// `Future` that resolves when connection shutdown is complete.
    ///
    /// The stream yields whatever the handler's `poll_close` produces. The future is the
    /// muxer's `close`. All other state (in-flight negotiations, pending outbound substream
    /// requests, protocol sets) is dropped here.
    pub(crate) fn close(
        self,
    ) -> (
        impl futures::Stream<Item = THandler::ToBehaviour>,
        impl Future<Output = io::Result<()>>,
    ) {
        let Connection {
            mut handler,
            muxing,
            ..
        } = self;

        (
            stream::poll_fn(move |cx| handler.poll_close(cx)),
            muxing.close(),
        )
    }

    /// Polls the handler and the substream, forwarding events from the former to the latter and
    /// vice versa.
    ///
    /// Each loop iteration drives, in this order:
    /// 1. pending outbound substream requests (to report timeouts),
    /// 2. the handler,
    /// 3. outbound negotiations,
    /// 4. inbound negotiations,
    /// 5. the idle/keep-alive shutdown plan,
    /// 6. the muxer itself,
    /// 7. opening a requested outbound substream,
    /// 8. accepting an inbound substream,
    /// 9. local protocol changes.
    ///
    /// Whenever one of these stages makes progress without producing a return value, the loop
    /// starts over so the handler can react.
    ///
    /// Returns:
    /// - `Ready(Ok(Event::Handler(_)))` when the handler emits `NotifyBehaviour`;
    /// - `Ready(Ok(Event::AddressChange(_)))` when the muxer reports a new remote address;
    /// - `Ready(Err(ConnectionError::KeepAliveTimeout))` when the connection is idle and the
    ///   planned shutdown is due;
    /// - `Ready(Err(_))` for muxer errors, propagated with `?`;
    /// - `Pending` when no stage can make progress.
    #[tracing::instrument(level = "debug", name = "Connection::poll", skip(self, cx))]
    pub(crate) fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        let Self {
            requested_substreams,
            muxing,
            handler,
            negotiating_out,
            negotiating_in,
            shutdown,
            max_negotiating_inbound_streams,
            substream_upgrade_protocol_override,
            local_supported_protocols: supported_protocols,
            remote_supported_protocols,
            protocol_buffer,
            idle_timeout,
            stream_counter,
            ..
        } = self.get_mut();

        loop {
            // Clean up outbound substream requests: `Ok(())` is a request whose data was
            // already moved into `negotiating_out` (see the `extract` call below), `Err(info)`
            // is one whose timeout elapsed before the muxer provided a substream.
            match requested_substreams.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(()))) => continue,
                Poll::Ready(Some(Err(info))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError {
                            info,
                            error: StreamUpgradeError::Timeout,
                        },
                    ));
                    continue;
                }
                Poll::Ready(None) | Poll::Pending => {}
            }

            // Poll the [`ConnectionHandler`].
            match handler.poll(cx) {
                Poll::Pending => {}
                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
                    // Queue the request; the substream itself is opened further down once the
                    // muxer is ready. The timeout starts ticking now.
                    let timeout = *protocol.timeout();
                    let (upgrade, user_data) = protocol.into_upgrade();

                    requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
                    continue; // Poll handler until exhausted.
                }
                Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
                    return Poll::Ready(Ok(Event::Handler(event)));
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(protocols),
                )) => {
                    // The handler is notified first. Only then is the set extended with the
                    // contents of `protocol_buffer`, which `ProtocolsChange::add` filled
                    // (presumably with exactly the newly added protocols).
                    if let Some(added) =
                        ProtocolsChange::add(remote_supported_protocols, protocols, protocol_buffer)
                    {
                        handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added));
                        remote_supported_protocols.extend(protocol_buffer.drain(..));
                    }
                    continue;
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(protocols),
                )) => {
                    // No extra bookkeeping here; `ProtocolsChange::remove` is assumed to update
                    // `remote_supported_protocols` itself.
                    if let Some(removed) = ProtocolsChange::remove(
                        remote_supported_protocols,
                        protocols,
                        protocol_buffer,
                    ) {
                        handler
                            .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed));
                    }
                    continue;
                }
            }

            // In case the [`ConnectionHandler`] can not make any more progress, poll the
            // negotiating outbound streams.
            match negotiating_out.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
                        FullyNegotiatedOutbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(error)))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError { info, error },
                    ));
                    continue;
                }
            }

            // In case both the [`ConnectionHandler`] and the negotiating outbound streams can not
            // make any more progress, poll the negotiating inbound streams.
            match negotiating_in.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
                        FullyNegotiatedInbound { protocol, info },
                    ));
                    continue;
                }
                // Only errors from the upgrade itself reach the handler. Negotiation-level
                // failures of inbound streams (I/O, no common protocol, timeout) are just
                // logged below.
                Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => {
                    handler.on_connection_event(ConnectionEvent::ListenUpgradeError(
                        ListenUpgradeError { info, error },
                    ));
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => {
                    tracing::debug!("failed to upgrade inbound stream: {e}");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => {
                    tracing::debug!("no protocol could be agreed upon for inbound stream");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => {
                    tracing::debug!("inbound stream upgrade timed out");
                    continue;
                }
            }

            // Check if the connection (and handler) should be shut down.
            // As long as we're still negotiating substreams or have
            // any active streams shutdown is always postponed.
            if negotiating_in.is_empty()
                && negotiating_out.is_empty()
                && requested_substreams.is_empty()
                && stream_counter.has_no_active_streams()
            {
                // `None` means "keep the current plan", i.e. let a running timer keep ticking.
                if let Some(new_timeout) =
                    compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
                {
                    *shutdown = new_timeout;
                }

                match shutdown {
                    Shutdown::None => {}
                    Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
                    Shutdown::Later(delay) => match Future::poll(Pin::new(delay), cx) {
                        Poll::Ready(_) => {
                            return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
                        }
                        Poll::Pending => {}
                    },
                }
            } else {
                // Any activity cancels a previously planned shutdown.
                *shutdown = Shutdown::None;
            }

            // Drive the muxer itself; its errors end the connection via `?`.
            match muxing.poll_unpin(cx)? {
                Poll::Pending => {}
                Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
                    handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
                        new_address: &address,
                    }));
                    return Poll::Ready(Ok(Event::AddressChange(address)));
                }
            }

            // Open an outbound substream only if the handler has an outstanding request. The
            // substream goes to whichever request `iter_mut` yields first.
            if let Some(requested_substream) = requested_substreams.iter_mut().next() {
                match muxing.poll_outbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        // Moves the request's data out and leaves the entry `Done`. It is then
                        // removed from `requested_substreams` the next time that set is polled.
                        let (user_data, timeout, upgrade) = requested_substream.extract();

                        negotiating_out.push(StreamUpgrade::new_outbound(
                            substream,
                            user_data,
                            timeout,
                            upgrade,
                            *substream_upgrade_protocol_override,
                            stream_counter.clone(),
                        ));

                        // Go back to the top,
                        // handler can potentially make progress again.
                        continue;
                    }
                }
            }

            // Accept new inbound substreams only while below the negotiation limit. Otherwise
            // `poll_inbound` is not called at all.
            if negotiating_in.len() < *max_negotiating_inbound_streams {
                match muxing.poll_inbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let protocol = handler.listen_protocol();

                        negotiating_in.push(StreamUpgrade::new_inbound(
                            substream,
                            protocol,
                            stream_counter.clone(),
                        ));

                        // Go back to the top,
                        // handler can potentially make progress again.
                        continue;
                    }
                }
            }

            // Compare the handler's current listen protocols with the cached snapshot and
            // report any differences back to the handler.
            let changes = ProtocolsChange::from_full_sets(
                supported_protocols,
                handler.listen_protocol().upgrade().protocol_info(),
                protocol_buffer,
            );

            if !changes.is_empty() {
                for change in changes {
                    handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
                }
                // Go back to the top, handler can potentially make progress again.
                continue;
            }

            // Nothing can make progress, return `Pending`.
            return Poll::Pending;
        }
    }

    /// Test helper: polls the connection once with a waker that does nothing.
    #[cfg(test)]
    fn poll_noop_waker(&mut self) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref()))
    }
}

/// Snapshots the handler's current listen protocols.
///
/// Each protocol is keyed by its string form (via [`AsStrHashEq`]) and mapped to `true`.
fn gather_supported_protocols<C: ConnectionHandler>(
    handler: &C,
) -> HashMap<AsStrHashEq<<C::InboundProtocol as UpgradeInfoSend>::Info>, bool> {
    let listen = handler.listen_protocol();
    let mut supported = HashMap::new();
    for info in listen.upgrade().protocol_info() {
        supported.insert(AsStrHashEq(info), true);
    }
    supported
}

/// Decides how the planned shutdown of an idle connection should change.
///
/// Returns `Some(new_plan)` when the plan must be replaced:
/// - the handler wants the connection kept alive → `Shutdown::None`;
/// - the handler does not, and `idle_timeout` is zero → `Shutdown::Asap` (even if a timer is
///   already running);
/// - the handler does not, and no timer is running yet → `Shutdown::Later` with a fresh timer
///   of at most `idle_timeout`.
///
/// Returns `None` when a `Shutdown::Later` timer is already running and should keep ticking.
fn compute_new_shutdown(
    handler_keep_alive: bool,
    current_shutdown: &Shutdown,
    idle_timeout: Duration,
) -> Option<Shutdown> {
    if handler_keep_alive {
        return Some(Shutdown::None);
    }

    if idle_timeout == Duration::ZERO {
        return Some(Shutdown::Asap);
    }

    if matches!(current_shutdown, Shutdown::Later(_)) {
        // Leave the already running timer alone.
        return None;
    }

    let bounded_timeout = checked_add_fraction(Instant::now(), idle_timeout);
    Some(Shutdown::Later(Delay::new(bounded_timeout)))
}

/// Returns `duration`, halved as many times as needed for `start + duration` to be
/// representable as an [`Instant`].
///
/// [`Instant`] depends on the underlying platform and can only represent a limited range of
/// points in time. The result is therefore not necessarily the longest [`Duration`] that can
/// be added to `start`, but adding it is guaranteed to succeed.
fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
    loop {
        if start.checked_add(duration).is_some() {
            return duration;
        }

        tracing::debug!(start=?start, duration=?duration, "start + duration cannot be presented, halving duration");

        duration /= 2;
    }
}

/// Borrowed view of the addresses of an incoming connection that is still being negotiated.
#[derive(Debug, Copy, Clone)]
pub(crate) struct IncomingInfo<'a> {
    /// Local address of the connection.
    pub(crate) local_addr: &'a Multiaddr,
    /// Address through which data can be sent back to the remote.
    pub(crate) send_back_addr: &'a Multiaddr,
}

impl IncomingInfo<'_> {
    /// Turns the borrowed addresses into an owned [`ConnectedPoint::Listener`].
    pub(crate) fn create_connected_point(&self) -> ConnectedPoint {
        let IncomingInfo {
            local_addr,
            send_back_addr,
        } = *self;

        ConnectedPoint::Listener {
            local_addr: local_addr.clone(),
            send_back_addr: send_back_addr.clone(),
        }
    }
}

/// A substream that is being negotiated and upgraded, raced against a timeout.
///
/// As a future it resolves to the stored user data together with the upgrade's outcome.
struct StreamUpgrade<UserData, TOk, TErr> {
    /// Handed back when the future resolves; `None` afterwards.
    user_data: Option<UserData>,
    /// If this elapses before `upgrade` completes, the result is `StreamUpgradeError::Timeout`.
    timeout: Delay,
    /// Protocol negotiation (multistream-select) followed by the actual upgrade.
    upgrade: BoxFuture<'static, Result<TOk, StreamUpgradeError<TErr>>>,
}

impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
    /// Creates the upgrade future for a freshly opened outbound substream.
    ///
    /// The future first runs multistream-select as the dialer. It uses `version_override` if
    /// that differs from the default version, and the default otherwise. It then applies
    /// `upgrade` to the negotiated stream, which is tied to `counter`. `timeout` bounds the
    /// whole process and is checked each time this future is polled.
    fn new_outbound<Upgrade>(
        substream: SubstreamBox,
        user_data: UserData,
        timeout: Delay,
        upgrade: Upgrade,
        version_override: Option<upgrade::Version>,
        counter: ActiveStreamCounter,
    ) -> Self
    where
        Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
    {
        let default_version = upgrade::Version::default();
        // An override equal to the default counts as "no override" and is not logged.
        let effective_version = match version_override.filter(|v| *v != default_version) {
            Some(overridden) => {
                tracing::debug!(
                    "Substream upgrade protocol override: {:?} -> {:?}",
                    default_version,
                    overridden
                );
                overridden
            }
            None => default_version,
        };
        let protocol_names = upgrade.protocol_info();

        let negotiation = async move {
            let (negotiated_info, negotiated_stream) = multistream_select::dialer_select_proto(
                substream,
                protocol_names,
                effective_version,
            )
            .await
            .map_err(to_stream_upgrade_error)?;

            upgrade
                .upgrade_outbound(Stream::new(negotiated_stream, counter), negotiated_info)
                .await
                .map_err(StreamUpgradeError::Apply)
        };

        Self {
            user_data: Some(user_data),
            timeout,
            upgrade: Box::pin(negotiation),
        }
    }
}

impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
    /// Creates the upgrade future for a newly accepted inbound substream.
    ///
    /// The timeout, upgrade and user data are all taken from `protocol`. The future runs
    /// multistream-select as the listener and then applies the inbound upgrade to the
    /// negotiated stream, which is tied to `counter`.
    fn new_inbound<Upgrade>(
        substream: SubstreamBox,
        protocol: SubstreamProtocol<Upgrade, UserData>,
        counter: ActiveStreamCounter,
    ) -> Self
    where
        Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
    {
        let upgrade_timeout = *protocol.timeout();
        let (inbound_upgrade, open_info) = protocol.into_upgrade();
        let protocol_names = inbound_upgrade.protocol_info();

        let negotiation = async move {
            let (negotiated_info, negotiated_stream) =
                multistream_select::listener_select_proto(substream, protocol_names)
                    .await
                    .map_err(to_stream_upgrade_error)?;

            inbound_upgrade
                .upgrade_inbound(Stream::new(negotiated_stream, counter), negotiated_info)
                .await
                .map_err(StreamUpgradeError::Apply)
        };

        Self {
            user_data: Some(open_info),
            timeout: Delay::new(upgrade_timeout),
            upgrade: Box::pin(negotiation),
        }
    }
}

/// Maps a multistream-select [`NegotiationError`] onto a [`StreamUpgradeError`].
///
/// A failed negotiation becomes `NegotiationFailed`. Every protocol-level error becomes `Io`:
/// I/O errors are unwrapped, and anything else is wrapped with `io::Error::other`.
fn to_stream_upgrade_error<T>(e: NegotiationError) -> StreamUpgradeError<T> {
    match e {
        NegotiationError::Failed => StreamUpgradeError::NegotiationFailed,
        NegotiationError::ProtocolError(protocol_error) => {
            let io_error = match protocol_error {
                ProtocolError::IoError(inner) => inner,
                other => io::Error::other(other),
            };
            StreamUpgradeError::Io(io_error)
        }
    }
}

impl<UserData, TOk, TErr> Unpin for StreamUpgrade<UserData, TOk, TErr> {}

impl<UserData, TOk, TErr> Future for StreamUpgrade<UserData, TOk, TErr> {
    type Output = (UserData, Result<TOk, StreamUpgradeError<TErr>>);

    /// Resolves with the user data plus either the upgrade's result or
    /// `StreamUpgradeError::Timeout`, whichever is ready first.
    ///
    /// On every poll the timeout is checked before the upgrade is polled.
    ///
    /// # Panics
    ///
    /// If polled again after it has resolved.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let outcome = if self.timeout.poll_unpin(cx).is_ready() {
            Err(StreamUpgradeError::Timeout)
        } else {
            futures::ready!(self.upgrade.poll_unpin(cx))
        };

        let user_data = self
            .user_data
            .take()
            .expect("Future not to be polled again once ready.");

        Poll::Ready((user_data, outcome))
    }
}

/// An outbound substream request from the handler that is waiting for the muxer to open a
/// substream.
///
/// As a future it resolves to `Ok(())` once its data has been taken out with `extract`, or to
/// `Err(user_data)` if `timeout` elapses first.
enum SubstreamRequested<UserData, Upgrade> {
    /// Still waiting for a substream.
    Waiting {
        /// Handler-provided data, returned together with the result or the timeout error.
        user_data: UserData,
        /// The upgrade timeout; it is already running while we wait for the substream.
        timeout: Delay,
        /// The upgrade to apply once a substream is available.
        upgrade: Upgrade,
        /// A waker to notify our [`FuturesUnordered`] that we have extracted the data.
        ///
        /// This will ensure that we will get polled again in the next iteration which allows us to
        /// resolve with `Ok(())` and be removed from the [`FuturesUnordered`].
        extracted_waker: Option<Waker>,
    },
    /// The data was extracted or the timeout fired; nothing is left in this entry.
    Done,
}

impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
    /// Creates a request that waits for an outbound substream. `timeout` starts running
    /// immediately.
    fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
        let deadline = Delay::new(timeout);
        SubstreamRequested::Waiting {
            user_data,
            timeout: deadline,
            upgrade,
            extracted_waker: None,
        }
    }

    /// Takes the request's data out, leaving `Done` behind.
    ///
    /// Wakes the waker recorded by the last `poll`, if any, so that this entry is polled again,
    /// resolves with `Ok(())` and is removed.
    ///
    /// # Panics
    ///
    /// If the request is already `Done` (extracted before, or timed out).
    fn extract(&mut self) -> (UserData, Delay, Upgrade) {
        let SubstreamRequested::Waiting {
            user_data,
            timeout,
            upgrade,
            extracted_waker,
        } = mem::replace(self, Self::Done)
        else {
            panic!("cannot extract twice");
        };

        if let Some(waker) = extracted_waker {
            waker.wake();
        }

        (user_data, timeout, upgrade)
    }
}

impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}

impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
    type Output = Result<(), UserData>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        match mem::replace(this, Self::Done) {
            SubstreamRequested::Waiting {
                user_data,
                upgrade,
                mut timeout,
                ..
            } => match timeout.poll_unpin(cx) {
                Poll::Ready(()) => Poll::Ready(Err(user_data)),
                Poll::Pending => {
                    *this = Self::Waiting {
                        user_data,
                        upgrade,
                        timeout,
                        extracted_waker: Some(cx.waker().clone()),
                    };
                    Poll::Pending
                }
            },
            SubstreamRequested::Done => Poll::Ready(Ok(())),
        }
    }
}

/// The options for a planned connection & handler shutdown.
///
/// A shutdown is planned anew based on the return value of
/// [`ConnectionHandler::connection_keep_alive`] of the underlying handler
/// each time `Connection::poll` finds the connection idle. Idle means no inbound or outbound
/// substreams are being negotiated, no outbound substream requests are pending, and there are
/// no active streams.
///
/// As soon as the connection is not idle, any planned shutdown is cancelled. The shutdown is
/// therefore always a graceful, "idle" one.
#[derive(Debug)]
enum Shutdown {
    /// No shutdown is planned.
    None,
    /// A shutdown is planned as soon as possible.
    ///
    /// Chosen when the configured idle timeout is zero.
    Asap,
    /// A shutdown is planned for when a `Delay` has elapsed.
    Later(Delay),
}

// Key wrapper that avoids allocations when storing protocols in a `HashMap`.
// Instead of allocating a `String` per key, equality and hashing both go through the
// wrapped value's `T::as_ref()` string slice. Two wrappers are therefore equal, and hash
// identically, exactly when their string forms match.
pub(crate) struct AsStrHashEq<T>(pub(crate) T);

impl<T: AsRef<str>> Eq for AsStrHashEq<T> {}

impl<T: AsRef<str>> PartialEq for AsStrHashEq<T> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs): (&str, &str) = (self.0.as_ref(), other.0.as_ref());
        lhs == rhs
    }
}

impl<T: AsRef<str>> std::hash::Hash for AsStrHashEq<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Must agree with `eq`: hash exactly the string slice that is compared there.
        let key: &str = self.0.as_ref();
        <str as std::hash::Hash>::hash(key, state);
    }
}

#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        sync::{Arc, Weak},
        time::Instant,
    };

    use futures::{future, AsyncRead, AsyncWrite};
    use libp2p_core::{
        upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo},
        StreamMuxer,
    };
    use quickcheck::*;
    use tracing_subscriber::EnvFilter;

    use super::*;
    use crate::dummy;

    /// Property: when the muxer offers inbound streams on every poll, a single poll
    /// of the connection leaves exactly `max_negotiating_inbound_streams` of them
    /// alive (each accepted substream holds one `Weak` to the shared counter).
    #[test]
    fn max_negotiating_inbound_streams() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn check(limit: u8) {
            let limit = usize::from(limit);

            let counter = Arc::new(());
            let muxer = StreamMuxerBox::new(DummyStreamMuxer {
                counter: Arc::clone(&counter),
            });
            let handler = MockConnectionHandler::new(Duration::from_secs(10));
            let mut connection = Connection::new(muxer, handler, None, limit, Duration::ZERO);

            assert!(connection.poll_noop_waker().is_pending());
            assert_eq!(
                Arc::weak_count(&counter),
                limit,
                "Expect no more than the maximum number of allowed streams"
            );
        }

        QuickCheck::new().quickcheck(check as fn(_));
    }

    /// The upgrade timeout of an outbound stream must start counting when the
    /// handler requests the stream, even though the muxer never delivers one.
    #[test]
    fn outbound_stream_timeout_starts_on_request() {
        let upgrade_timeout = Duration::from_secs(1);
        let muxer = StreamMuxerBox::new(PendingStreamMuxer);
        let handler = MockConnectionHandler::new(upgrade_timeout);
        let mut connection = Connection::new(muxer, handler, None, 2, Duration::ZERO);

        // Request an outbound stream; `PendingStreamMuxer` never produces it.
        connection.handler.open_new_outbound();
        let _ = connection.poll_noop_waker();

        // Wait past the upgrade timeout and poll once more so it can be observed.
        let wait = upgrade_timeout + Duration::from_secs(1);
        std::thread::sleep(wait);
        let _ = connection.poll_noop_waker();

        let error = connection.handler.error.unwrap();
        assert!(matches!(error, StreamUpgradeError::Timeout))
    }

    /// Changing the set of protocols the handler listens on must be reported back
    /// to it as `LocalProtocolsChange` events that contain only the protocols
    /// actually added or removed since the previous poll.
    #[test]
    fn propagates_changes_to_supported_inbound_protocols() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // First, start listening on a single protocol.
        connection.handler.listen_on(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]);
        assert!(connection.handler.local_removed.is_empty());

        // Second, listen on two protocols.
        connection.handler.listen_on(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        // "/foo" was already known, so only "/bar" is reported as added.
        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocols"
        );
        assert!(connection.handler.local_removed.is_empty());

        // Third, stop listening on the first protocol.
        connection.handler.listen_on(&["/bar"]);
        let _ = connection.poll_noop_waker();

        // No new additions; "/foo" is reported as removed.
        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]);
    }

    /// Remote-protocol reports from the handler must only come back to it as
    /// `RemoteProtocolsChange` events when they actually change the known set:
    /// re-announced protocols are not repeated, and removing a protocol that was
    /// never announced produces no event.
    #[test]
    fn only_propagtes_actual_changes_to_remote_protocols_to_handler() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // First, remote supports a single protocol.
        connection.handler.remote_adds_support_for(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]);
        assert!(connection.handler.remote_removed.is_empty());

        // Second, it adds a protocol but also still includes the first one.
        connection.handler.remote_adds_support_for(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocol"
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Third, the remote drops support for a protocol it never advertised (we can't
        // control what handlers report, so this needs to be handled gracefully).
        connection.handler.remote_removes_support_for(&["/baz"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Fourth, the remote drops support for a protocol that was previously supported.
        connection.handler.remote_removes_support_for(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]);
    }

    /// With `dummy::ConnectionHandler` and a muxer that never yields a stream, the
    /// connection is expected to close with `ConnectionError::KeepAliveTimeout`
    /// once `idle_timeout` has passed.
    #[tokio::test]
    async fn idle_timeout_with_keep_alive_no() {
        let idle_timeout = Duration::from_millis(100);
        let muxer = StreamMuxerBox::new(PendingStreamMuxer);

        let mut connection = Connection::new(
            muxer,
            dummy::ConnectionHandler,
            None,
            0,
            idle_timeout,
        );

        // Before the idle timeout elapses the connection must still be pending.
        assert!(connection.poll_noop_waker().is_pending());

        tokio::time::sleep(idle_timeout).await;

        let outcome = connection.poll_noop_waker();
        assert!(matches!(
            outcome,
            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
        ));
    }

    /// `checked_add_fraction` must return a duration that can be added to `start`
    /// without overflowing, even when asked to add `u64::MAX` seconds.
    #[test]
    fn checked_add_fraction_can_add_u64_max() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        let start = Instant::now();
        let huge = Duration::from_secs(u64::MAX);

        let duration = checked_add_fraction(start, huge);

        assert!(start.checked_add(duration).is_some())
    }

    /// Property: `compute_new_shutdown` never panics, whatever the handler's
    /// keep-alive answer, the current shutdown state and the idle timeout are.
    #[test]
    fn compute_new_shutdown_does_not_panic() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        /// `Arbitrary`-capable wrapper around [`Shutdown`].
        #[derive(Debug)]
        struct ArbitraryShutdown(Shutdown);

        impl Clone for ArbitraryShutdown {
            fn clone(&self) -> Self {
                Self(match &self.0 {
                    Shutdown::None => Shutdown::None,
                    Shutdown::Asap => Shutdown::Asap,
                    // `Delay` is not `Clone`, and compute_new_shutdown does not touch
                    // the delay, so a fresh placeholder delay is an adequate copy.
                    Shutdown::Later(_) => Shutdown::Later(Delay::new(Duration::from_secs(1))),
                })
            }
        }

        impl Arbitrary for ArbitraryShutdown {
            fn arbitrary(g: &mut Gen) -> Self {
                let variant = g.gen_range(1u8..4);
                let shutdown = match variant {
                    1 => Shutdown::None,
                    2 => Shutdown::Asap,
                    3 => {
                        let secs = u64::from(u32::arbitrary(g));
                        Shutdown::Later(Delay::new(Duration::from_secs(secs)))
                    }
                    _ => unreachable!(),
                };

                Self(shutdown)
            }
        }

        fn check(
            handler_keep_alive: bool,
            current_shutdown: ArbitraryShutdown,
            idle_timeout: Duration,
        ) {
            let ArbitraryShutdown(shutdown) = current_shutdown;
            compute_new_shutdown(handler_keep_alive, &shutdown, idle_timeout);
        }

        QuickCheck::new().quickcheck(check as fn(_, _, _));
    }

    /// A [`StreamMuxer`] that hands out a new inbound substream on every
    /// `poll_inbound` call. Each substream holds a `Weak` to `counter`, so
    /// `Arc::weak_count(&counter)` tells how many of them are still alive.
    struct DummyStreamMuxer {
        counter: Arc<()>,
    }

    impl StreamMuxer for DummyStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Infallible;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            let tracker = Arc::downgrade(&self.counter);
            Poll::Ready(Ok(PendingSubstream { _weak: tracker }))
        }

        /// Outbound substreams are never produced.
        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        /// Closing completes immediately.
        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// No muxer-level events are ever emitted.
        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// A [`StreamMuxer`] which never returns a stream.
    ///
    /// Every method returns `Poll::Pending`: no inbound or outbound substream is
    /// ever produced, no muxer event is emitted, and closing never completes.
    struct PendingStreamMuxer;

    impl StreamMuxer for PendingStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Infallible;

        // Never yields an inbound substream.
        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        // Never yields an outbound substream.
        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        // Closing never completes.
        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        // No muxer-level events are ever emitted.
        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// A substream that never makes progress: every read, write, flush and close
    /// returns `Poll::Pending`.
    ///
    /// The `Weak<()>` is held only so that tests can count live substreams with
    /// `Arc::weak_count` on the corresponding `Arc<()>`.
    struct PendingSubstream {
        _weak: Weak<()>,
    }

    impl AsyncRead for PendingSubstream {
        // Never produces data.
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }
    }

    impl AsyncWrite for PendingSubstream {
        // Never accepts data.
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }

        // Flushing never completes.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }

        // Closing never completes.
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Handler that requests a single outbound stream on demand and records the
    /// `DialUpgradeError` it is notified of, if any.
    struct MockConnectionHandler {
        /// Set by `open_new_outbound`; cleared by the next `poll`.
        outbound_requested: bool,
        /// Error carried by the most recent `DialUpgradeError` event.
        error: Option<StreamUpgradeError<Infallible>>,
        /// Timeout attached to both the listen protocol and outbound requests.
        upgrade_timeout: Duration,
    }

    impl MockConnectionHandler {
        /// Creates a handler with no pending outbound request and no recorded error.
        fn new(upgrade_timeout: Duration) -> Self {
            Self {
                upgrade_timeout,
                error: None,
                outbound_requested: false,
            }
        }

        /// Makes the next `poll` emit an `OutboundSubstreamRequest`.
        fn open_new_outbound(&mut self) {
            self.outbound_requested = true;
        }
    }

    /// Handler whose listen protocols can be reconfigured between polls and which
    /// records every protocol-change notification it receives.
    #[derive(Default)]
    struct ConfigurableProtocolConnectionHandler {
        /// Queued events handed out by `poll`.
        events: Vec<ConnectionHandlerEvent<DeniedUpgrade, (), Infallible>>,
        /// Protocols advertised through `listen_protocol`.
        active_protocols: HashSet<StreamProtocol>,
        /// Payloads of received `LocalProtocolsChange` additions.
        local_added: Vec<Vec<StreamProtocol>>,
        /// Payloads of received `LocalProtocolsChange` removals.
        local_removed: Vec<Vec<StreamProtocol>>,
        /// Payloads of received `RemoteProtocolsChange` additions.
        remote_added: Vec<Vec<StreamProtocol>>,
        /// Payloads of received `RemoteProtocolsChange` removals.
        remote_removed: Vec<Vec<StreamProtocol>>,
    }

    impl ConfigurableProtocolConnectionHandler {
        /// Turns protocol name literals into any collection of `StreamProtocol`s.
        fn to_protocols<C: FromIterator<StreamProtocol>>(protocols: &[&'static str]) -> C {
            protocols.iter().map(|&name| StreamProtocol::new(name)).collect()
        }

        /// Replaces the set of protocols this handler listens on.
        fn listen_on(&mut self, protocols: &[&'static str]) {
            self.active_protocols = Self::to_protocols(protocols);
        }

        /// Queues a report that the remote now supports `protocols`.
        fn remote_adds_support_for(&mut self, protocols: &[&'static str]) {
            let support = ProtocolSupport::Added(Self::to_protocols(protocols));
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(support));
        }

        /// Queues a report that the remote no longer supports `protocols`.
        fn remote_removes_support_for(&mut self, protocols: &[&'static str]) {
            let support = ProtocolSupport::Removed(Self::to_protocols(protocols));
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(support));
        }
    }

    impl ConnectionHandler for MockConnectionHandler {
        type FromBehaviour = Infallible;
        type ToBehaviour = Infallible;
        type InboundProtocol = DeniedUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        /// Denies every inbound stream, using the configured upgrade timeout.
        fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
            SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
        }

        /// Records dial upgrade errors; every other event is ignored. Negotiated
        /// streams cannot occur because the upgrade output is uninhabited.
        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
        ) {
            match event {
                ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
                    self.error = Some(error)
                }
                ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
                    protocol,
                    ..
                }) => libp2p_core::util::unreachable(protocol),
                ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
                    protocol,
                    ..
                }) => libp2p_core::util::unreachable(protocol),
                ConnectionEvent::AddressChange(_)
                | ConnectionEvent::ListenUpgradeError(_)
                | ConnectionEvent::LocalProtocolsChange(_)
                | ConnectionEvent::RemoteProtocolsChange(_) => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            libp2p_core::util::unreachable(event)
        }

        /// Always keeps the connection alive.
        fn connection_keep_alive(&self) -> bool {
            true
        }

        /// Emits one outbound substream request per `open_new_outbound` call.
        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
            // Reading the flag also resets it, so each request is emitted exactly once.
            if !std::mem::take(&mut self.outbound_requested) {
                return Poll::Pending;
            }

            let protocol =
                SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout);
            Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol })
        }
    }

    impl ConnectionHandler for ConfigurableProtocolConnectionHandler {
        type FromBehaviour = Infallible;
        type ToBehaviour = Infallible;
        type InboundProtocol = ManyProtocolsUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        /// Advertises exactly the protocols currently in `active_protocols`.
        fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
            SubstreamProtocol::new(
                ManyProtocolsUpgrade {
                    // Clone the elements straight into the `Vec` instead of cloning the
                    // whole set first and then converting it.
                    protocols: self.active_protocols.iter().cloned().collect(),
                },
                (),
            )
        }

        /// Records every local/remote protocol change; other events are ignored.
        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
        ) {
            match event {
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.local_added.push(added.cloned().collect())
                }
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.local_removed.push(removed.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.remote_added.push(added.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.remote_removed.push(removed.cloned().collect())
                }
                _ => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            libp2p_core::util::unreachable(event)
        }

        /// Always keeps the connection alive.
        fn connection_keep_alive(&self) -> bool {
            true
        }

        /// Hands out queued events in the order they were queued.
        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
            // FIFO: popping from the back would reverse e.g. an "added" report that was
            // queued before a "removed" report, producing the wrong final protocol set.
            if self.events.is_empty() {
                return Poll::Pending;
            }

            Poll::Ready(self.events.remove(0))
        }
    }

    /// Upgrade that advertises a configurable list of protocols. Negotiating any of
    /// them, inbound or outbound, simply yields the raw stream back.
    struct ManyProtocolsUpgrade {
        protocols: Vec<StreamProtocol>,
    }

    impl UpgradeInfo for ManyProtocolsUpgrade {
        type Info = StreamProtocol;
        type InfoIter = std::vec::IntoIter<Self::Info>;

        fn protocol_info(&self) -> Self::InfoIter {
            let advertised: Vec<StreamProtocol> = self.protocols.to_vec();
            advertised.into_iter()
        }
    }

    impl<C> InboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Infallible;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        /// Succeeds immediately with the unmodified stream.
        fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ok(stream)
        }
    }

    impl<C> OutboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Infallible;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        /// Succeeds immediately with the unmodified stream.
        fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ok(stream)
        }
    }
}

/// The endpoint roles associated with a pending peer-to-peer connection.
///
/// Built from a [`ConnectedPoint`] via `From`, keeping only the fields listed here.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum PendingPoint {
    /// The socket comes from a dialer.
    ///
    /// There is no single address associated with the Dialer of a pending
    /// connection. Addresses are dialed in parallel. Only once the first dial
    /// is successful is the address of the connection known.
    Dialer {
        /// Same as [`ConnectedPoint::Dialer`] `role_override`.
        role_override: Endpoint,
        /// Same as [`ConnectedPoint::Dialer`] `port_use`.
        port_use: PortUse,
    },
    /// The socket comes from a listener.
    Listener {
        /// Local connection address.
        local_addr: Multiaddr,
        /// Address used to send back data to the remote.
        send_back_addr: Multiaddr,
    },
}

impl From<ConnectedPoint> for PendingPoint {
    fn from(endpoint: ConnectedPoint) -> Self {
        match endpoint {
            ConnectedPoint::Dialer {
                role_override,
                port_use,
                ..
            } => PendingPoint::Dialer {
                role_override,
                port_use,
            },
            ConnectedPoint::Listener {
                local_addr,
                send_back_addr,
            } => PendingPoint::Listener {
                local_addr,
                send_back_addr,
            },
        }
    }
}
